{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# ECG Rhythm Classification\n",
    "## 1. Train Model\n",
    "### Sebastian D. Goodfellow, Ph.D."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Setup Notebook"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Standard library and 3rd party libraries\n",
    "import os\n",
    "import sys\n",
    "import numpy as np\n",
    "import tensorflow as tf\n",
    "\n",
    "# Local Libraries\n",
    "sys.path.insert(0, r'C:\\Users\\sebastian goodfellow\\Documents\\code\\deep_ecg')\n",
    "from deepecg.training.utils.plotting.training_data_validation import interval_plot_interact\n",
    "from deepecg.training.utils.devices.device_check import print_device_counts\n",
    "from deepecg.training.train.disc.data_generator import DataGenerator\n",
    "from deepecg.training.train.disc.train import train\n",
    "from deepecg.training.model.disc.model import Model\n",
    "from deepecg.config.config import DATA_DIR\n",
    "\n",
    "# Configure Notebook\n",
    "import warnings\n",
    "warnings.filterwarnings('ignore')\n",
    "%matplotlib inline\n",
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Resources"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Objective Function\n",
    "# https://stackoverflow.com/questions/44560549/unbalanced-data-and-weighted-cross-entropy\n",
    "\n",
    "# Global Average Pooling\n",
    "# https://alexisbcook.github.io/2017/global-average-pooling-layers-for-object-localization/\n",
    "# https://github.com/philipperemy/tensorflow-class-activation-mapping/blob/master/class_activation_map.py\n",
    "# https://github.com/AndersonJo/global-average-pooling"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 1. View Training Data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "cb7f3152f93a45c0b6011d050ef9581e",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "interactive(children=(IntSlider(value=852, description='label_id', max=1705), Output()), _dom_classes=('widget…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Data path\n",
    "data_path = os.path.join(DATA_DIR, 'training')\n",
    "\n",
    "# Launch plot\n",
    "interval_plot_interact(path=data_path, dataset='val')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 2. Test Data Generator"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<PrefetchDataset shapes: ((?, 12000, 1), (?,)), types: (tf.float32, tf.int32)>"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Initialize generator\n",
    "generator = DataGenerator(path=data_path, mode='train', shape=[12000, 1], batch_size=32, \n",
    "                          prefetch_buffer=1000, seed=0, num_parallel_calls=24)\n",
    "\n",
    "# View dataset\n",
    "generator.dataset"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 3. Initialize Model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "scrolled": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Workstation has 1 CPUs.\n",
      "Workstation has 2 GPUs.\n",
      "input | shape = (?, 12000, 1)\n",
      "layer_1 | shape = (?, 12000, 128)\n",
      "layer_2_res | shape = (?, 12000, 128)\n",
      "layer_2_skip | shape = (?, 12000, 128)\n",
      "layer_3_res | shape = (?, 12000, 128)\n",
      "layer_3_skip | shape = (?, 12000, 128)\n",
      "layer_4_res | shape = (?, 12000, 128)\n",
      "layer_4_skip | shape = (?, 12000, 128)\n",
      "layer_5_res | shape = (?, 12000, 128)\n",
      "layer_5_skip | shape = (?, 12000, 128)\n",
      "layer_6_res | shape = (?, 12000, 128)\n",
      "layer_6_skip | shape = (?, 12000, 128)\n",
      "layer_7_res | shape = (?, 12000, 128)\n",
      "layer_7_skip | shape = (?, 12000, 128)\n",
      "layer_8_res | shape = (?, 12000, 128)\n",
      "layer_8_skip | shape = (?, 12000, 128)\n",
      "layer_9_res | shape = (?, 12000, 128)\n",
      "layer_9_skip | shape = (?, 12000, 128)\n",
      "layer_10_skip | shape = (?, 12000, 128)\n",
      "output_skip_addition | shape = (?, 12000, 128)\n",
      "output_conv1 | shape = (?, 12000, 256)\n",
      "output_conv2 | shape = (?, 12000, 512)\n",
      "gap | shape = (?, 512)\n",
      "logits | shape = (?, 4)\n",
      "input | shape = (?, 12000, 1)\n",
      "layer_1 | shape = (?, 12000, 128)\n",
      "layer_2_res | shape = (?, 12000, 128)\n",
      "layer_2_skip | shape = (?, 12000, 128)\n",
      "layer_3_res | shape = (?, 12000, 128)\n",
      "layer_3_skip | shape = (?, 12000, 128)\n",
      "layer_4_res | shape = (?, 12000, 128)\n",
      "layer_4_skip | shape = (?, 12000, 128)\n",
      "layer_5_res | shape = (?, 12000, 128)\n",
      "layer_5_skip | shape = (?, 12000, 128)\n",
      "layer_6_res | shape = (?, 12000, 128)\n",
      "layer_6_skip | shape = (?, 12000, 128)\n",
      "layer_7_res | shape = (?, 12000, 128)\n",
      "layer_7_skip | shape = (?, 12000, 128)\n",
      "layer_8_res | shape = (?, 12000, 128)\n",
      "layer_8_skip | shape = (?, 12000, 128)\n",
      "layer_9_res | shape = (?, 12000, 128)\n",
      "layer_9_skip | shape = (?, 12000, 128)\n",
      "layer_10_skip | shape = (?, 12000, 128)\n",
      "output_skip_addition | shape = (?, 12000, 128)\n",
      "output_conv1 | shape = (?, 12000, 256)\n",
      "output_conv2 | shape = (?, 12000, 512)\n",
      "gap | shape = (?, 512)\n",
      "logits | shape = (?, 4)\n"
     ]
    }
   ],
   "source": [
    "# Print devices\n",
    "print_device_counts()\n",
    "\n",
    "# Set save path for graphs, summaries, and checkpoints\n",
    "save_path = r'C:\\Users\\sebastian goodfellow\\Desktop\\tensorboard\\deep_ecg\\wavenet_2'\n",
    "\n",
    "# Set model name\n",
    "model_name = 'test_7'\n",
    "\n",
    "# Maximum number of checkpoints to keep\n",
    "max_to_keep = 20\n",
    "\n",
    "# Set random state\n",
    "seed = 0                                                         \n",
    "\n",
    "# Get training dataset dimensions\n",
    "length, channels = (12000, 1)     \n",
    "\n",
    "# Number of classes\n",
    "classes = 4\n",
    "\n",
    "# Choose network\n",
    "network_name = 'DeepECGV7'\n",
    "\n",
    "# Set network inputs\n",
    "network_parameters = {'length': length, 'channels': channels, 'classes': classes, 'seed': seed}\n",
    "\n",
    "# Create model\n",
    "model = Model(model_name=model_name, \n",
    "              network_name=network_name, \n",
    "              network_parameters=network_parameters, \n",
    "              save_path=save_path,\n",
    "              data_path=data_path,\n",
    "              max_to_keep=max_to_keep)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 4. Train Model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "# Set hyper-parameters (NOTE: learning_rate_start is not passed to train() below)\n",
    "epochs = 200\n",
    "batch_size = 16\n",
    "learning_rate_start = 0.001            \n",
    "\n",
    "# Train model\n",
    "train(model=model, epochs=epochs, batch_size=batch_size)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
